{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
   "metadata": {},
   "source": [
    "# 👖 Autoencoders on Fashion MNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own autoencoder on the Fashion-MNIST dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from tensorflow.keras import layers, models, datasets, callbacks\n",
    "import tensorflow.keras.backend as K\n",
    "\n",
    "from notebooks.utils import display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input geometry expected by the encoder\n",
    "IMAGE_SIZE, CHANNELS = 32, 1\n",
    "\n",
    "# Model / training settings\n",
    "EMBEDDING_DIM = 2  # a 2-D latent space can be scatter-plotted directly\n",
    "BATCH_SIZE = 100\n",
    "EPOCHS = 3\n",
    "\n",
    "# NOTE(review): the two constants below are not referenced anywhere else in this notebook\n",
    "BUFFER_SIZE = 1000\n",
    "VALIDATION_SPLIT = 0.2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7716fac-0010-49b0-b98e-53be2259edde",
   "metadata": {},
   "source": [
    "## 1. Prepare the data <a name=\"prepare\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a73e5a4-1638-411c-8d3c-29f823424458",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Fashion-MNIST images and labels as (train, test) pairs\n",
    "train_split, test_split = datasets.fashion_mnist.load_data()\n",
    "x_train, y_train = train_split\n",
    "x_test, y_test = test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebae2f0d-59fd-4796-841f-7213eae638de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the data\n",
    "\n",
    "\n",
    "def preprocess(imgs, image_size=IMAGE_SIZE):\n",
    "    \"\"\"\n",
    "    Normalize and reshape the images.\n",
    "\n",
    "    Divides pixel values by 255, zero-pads each image (as evenly as possible\n",
    "    on both sides) up to `image_size` x `image_size`, and appends a trailing\n",
    "    channel axis.\n",
    "\n",
    "    imgs: array of shape (n, height, width).\n",
    "    image_size: target height/width. Defaults to IMAGE_SIZE so the padded\n",
    "        images always match the encoder's input shape.\n",
    "    Returns a float32 array of shape (n, image_size, image_size, 1).\n",
    "    Raises ValueError if the images are larger than `image_size`.\n",
    "    \"\"\"\n",
    "    imgs = imgs.astype(\"float32\") / 255.0\n",
    "    pad_h = image_size - imgs.shape[1]\n",
    "    pad_w = image_size - imgs.shape[2]\n",
    "    if pad_h < 0 or pad_w < 0:\n",
    "        raise ValueError(\n",
    "            f\"images of size {imgs.shape[1:]} do not fit in image_size={image_size}\"\n",
    "        )\n",
    "    # Any odd leftover pixel goes on the bottom/right edge\n",
    "    imgs = np.pad(\n",
    "        imgs,\n",
    "        ((0, 0), (pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)),\n",
    "        constant_values=0.0,\n",
    "    )\n",
    "    imgs = np.expand_dims(imgs, -1)\n",
    "    return imgs\n",
    "\n",
    "\n",
    "# NOTE: this overwrites x_train/x_test, so re-running this cell on its own would\n",
    "# re-process already-processed arrays; re-run the loading cell first\n",
    "x_train = preprocess(x_train)\n",
    "x_test = preprocess(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview training clothing images with the project's `display` helper\n",
    "# (imported from notebooks.utils, so it takes precedence over IPython's display)\n",
    "display(x_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 2. Build the autoencoder <a name=\"build\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "086e2584-c60d-4990-89f4-2092c44e023e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder: three stride-2 convolutions, then a dense projection to the latent space\n",
    "encoder_input = layers.Input(\n",
    "    shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS), name=\"encoder_input\"\n",
    ")\n",
    "x = encoder_input\n",
    "for n_filters in (32, 64, 128):\n",
    "    x = layers.Conv2D(\n",
    "        n_filters, (3, 3), strides=2, activation=\"relu\", padding=\"same\"\n",
    "    )(x)\n",
    "\n",
    "# Remember the spatial shape: the decoder reshapes its dense output back to it\n",
    "shape_before_flattening = K.int_shape(x)[1:]\n",
    "\n",
    "x = layers.Flatten()(x)\n",
    "encoder_output = layers.Dense(EMBEDDING_DIM, name=\"encoder_output\")(x)\n",
    "\n",
    "encoder = models.Model(encoder_input, encoder_output)\n",
    "encoder.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c409e63-1aea-42e2-8324-c3e2a12073ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoder: dense layer back to the encoder's pre-flatten shape, then three\n",
    "# stride-2 transposed convolutions (filter counts in reverse order)\n",
    "decoder_input = layers.Input(shape=(EMBEDDING_DIM,), name=\"decoder_input\")\n",
    "x = layers.Dense(np.prod(shape_before_flattening))(decoder_input)\n",
    "x = layers.Reshape(shape_before_flattening)(x)\n",
    "for n_filters in (128, 64, 32):\n",
    "    x = layers.Conv2DTranspose(\n",
    "        n_filters, (3, 3), strides=2, activation=\"relu\", padding=\"same\"\n",
    "    )(x)\n",
    "\n",
    "# Final stride-1 convolution down to CHANNELS maps; sigmoid keeps outputs in (0, 1)\n",
    "decoder_output = layers.Conv2D(\n",
    "    CHANNELS,\n",
    "    (3, 3),\n",
    "    strides=1,\n",
    "    activation=\"sigmoid\",\n",
    "    padding=\"same\",\n",
    "    name=\"decoder_output\",\n",
    ")(x)\n",
    "\n",
    "decoder = models.Model(decoder_input, decoder_output)\n",
    "decoder.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34dc7c69-26a8-4c17-aa24-792f1b0a69b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Autoencoder: feed the encoder's output straight into the decoder\n",
    "autoencoder_output = decoder(encoder_output)\n",
    "autoencoder = models.Model(encoder_input, autoencoder_output)\n",
    "autoencoder.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35b14665-4359-447b-be58-3fd58ba69084",
   "metadata": {},
   "source": [
    "## 3. Train the autoencoder <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b429fdad-ea9c-45a2-a556-eb950d793824",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the autoencoder with a binary cross-entropy reconstruction loss\n",
    "autoencoder.compile(\n",
    "    loss=\"binary_crossentropy\",\n",
    "    optimizer=\"adam\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the full model to ./checkpoint at the end of an epoch whenever the\n",
    "# (training) loss reaches a new minimum\n",
    "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
    "    filepath=\"./checkpoint\",\n",
    "    monitor=\"loss\",\n",
    "    mode=\"min\",\n",
    "    save_best_only=True,\n",
    "    save_weights_only=False,\n",
    "    save_freq=\"epoch\",\n",
    "    verbose=0,\n",
    ")\n",
    "\n",
    "# Write TensorBoard logs under ./logs\n",
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3c497b7-fa40-48df-b2bf-541239cc9400",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train the autoencoder to reproduce its input: the images are also the targets\n",
    "autoencoder.fit(\n",
    "    x=x_train,\n",
    "    y=x_train,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=EPOCHS,\n",
    "    shuffle=True,\n",
    "    validation_data=(x_test, x_test),\n",
    "    callbacks=[model_checkpoint_callback, tensorboard_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edb847d1-c22d-4923-ba92-0ecde0f12fca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the final models: the full autoencoder plus its encoder and decoder halves\n",
    "for model, path in [\n",
    "    (autoencoder, \"./models/autoencoder\"),\n",
    "    (encoder, \"./models/encoder\"),\n",
    "    (decoder, \"./models/decoder\"),\n",
    "]:\n",
    "    model.save(path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc0f31bc-77e6-49e8-bb76-51bca124744c",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 4. Reconstruct using the autoencoder <a name=\"reconstruct\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4d83729-71a2-4494-86a5-e17830974ef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first n_to_predict test images (and their matching labels) as examples\n",
    "n_to_predict = 5000\n",
    "example_images, example_labels = x_test[:n_to_predict], y_test[:n_to_predict]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c9b2a91-7cea-4595-a857-11f5ab00875e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruct the examples, then show the originals followed by the reconstructions\n",
    "predictions = autoencoder.predict(example_images)\n",
    "\n",
    "for title, images in [\n",
    "    (\"Example real clothing items\", example_images),\n",
    "    (\"Reconstructions\", predictions),\n",
    "]:\n",
    "    print(title)\n",
    "    display(images)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b77c88bb-ada4-4091-94e3-764f1385f1fc",
   "metadata": {},
   "source": [
    "## 5. Embed using the encoder <a name=\"encode\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e723c1c-136b-47e5-9972-ee964712d148",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every example image to its latent vector (one row per image)\n",
    "embeddings = encoder.predict(example_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ed4e9bd-df14-4832-a765-dfaf36d49fca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first ten latent vectors\n",
    "print(embeddings[0:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb208e8-6351-49ac-a68c-679a830f13bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter the encoded points in the 2-D latent space\n",
    "figsize = 8\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(figsize, figsize))\n",
    "ax.scatter(embeddings[:, 0], embeddings[:, 1], c=\"black\", alpha=0.5, s=3)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "138a34ca-67b4-42b7-a9fa-f7ffe397df49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour the embeddings by their integer class label (clothing type).\n",
    "# `example_labels` is selected together with `example_images`, so it is already\n",
    "# aligned row-for-row with `embeddings`.\n",
    "figsize = 8\n",
    "fig, ax = plt.subplots(figsize=(figsize, figsize))\n",
    "points = ax.scatter(\n",
    "    embeddings[:, 0],\n",
    "    embeddings[:, 1],\n",
    "    cmap=\"rainbow\",\n",
    "    c=example_labels,\n",
    "    alpha=0.8,\n",
    "    s=3,\n",
    ")\n",
    "fig.colorbar(points, ax=ax, label=\"class label\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0616b71-3354-419c-8ddb-f64fc29850ca",
   "metadata": {},
   "source": [
    "## 6. Generate using the decoder <a name=\"decode\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d494893-059f-42e4-825e-31c06fa3cd09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the range of the existing embeddings\n",
    "mins, maxs = np.min(embeddings, axis=0), np.max(embeddings, axis=0)\n",
    "\n",
    "# Sample points uniformly inside that bounding box of the latent space.\n",
    "# A seeded generator makes the sampled points (and therefore the decoded\n",
    "# images) reproducible across Restart & Run All.\n",
    "SAMPLE_SEED = 42\n",
    "rng = np.random.default_rng(SAMPLE_SEED)\n",
    "grid_width, grid_height = (6, 3)\n",
    "sample = rng.uniform(mins, maxs, size=(grid_width * grid_height, EMBEDDING_DIM))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba3b1c66-c89d-436a-b009-19f1f5a785e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the sampled latent points back into images with the decoder\n",
    "reconstructions = decoder.predict(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feea9b9d-8d3e-43f5-9ead-cd9e38367c00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw a plot of...\n",
    "figsize = 8\n",
    "plt.figure(figsize=(figsize, figsize))\n",
    "\n",
    "# ... the original embeddings ...\n",
    "plt.scatter(embeddings[:, 0], embeddings[:, 1], c=\"black\", alpha=0.5, s=2)\n",
    "\n",
    "# ... and the newly generated points in the latent space\n",
    "plt.scatter(sample[:, 0], sample[:, 1], c=\"#00B0F0\", alpha=1, s=40)\n",
    "plt.show()\n",
    "\n",
    "# Add underneath a grid of the decoded images, each captioned with its\n",
    "# latent coordinates rounded to one decimal place\n",
    "fig = plt.figure(figsize=(figsize, grid_height * 2))\n",
    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
    "\n",
    "for i in range(grid_width * grid_height):\n",
    "    ax = fig.add_subplot(grid_height, grid_width, i + 1)\n",
    "    ax.axis(\"off\")\n",
    "    ax.text(\n",
    "        0.5,\n",
    "        -0.35,\n",
    "        str(np.round(sample[i, :], 1)),\n",
    "        fontsize=10,\n",
    "        ha=\"center\",\n",
    "        transform=ax.transAxes,\n",
    "    )\n",
    "    # Drop the trailing channel axis, (H, W, 1) -> (H, W), so imshow applies the colormap\n",
    "    ax.imshow(reconstructions[i].squeeze(), cmap=\"Greys\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f64434a4-41c5-4225-ad31-9cf83f8797e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour the embeddings by their label, then decode an evenly spaced grid of\n",
    "# latent points covering the same region\n",
    "figsize = 12\n",
    "grid_size = 15\n",
    "plt.figure(figsize=(figsize, figsize))\n",
    "plt.scatter(\n",
    "    embeddings[:, 0],\n",
    "    embeddings[:, 1],\n",
    "    cmap=\"rainbow\",\n",
    "    c=example_labels,\n",
    "    alpha=0.8,\n",
    "    s=300,\n",
    ")\n",
    "plt.colorbar()\n",
    "plt.show()\n",
    "\n",
    "# Regular grid over the embeddings' bounding box; y runs from max to min so the\n",
    "# top row of the image grid matches the top of the scatter plot\n",
    "grid_x = np.linspace(embeddings[:, 0].min(), embeddings[:, 0].max(), grid_size)\n",
    "grid_y = np.linspace(embeddings[:, 1].max(), embeddings[:, 1].min(), grid_size)\n",
    "xv, yv = np.meshgrid(grid_x, grid_y)\n",
    "grid = np.column_stack([xv.ravel(), yv.ravel()])\n",
    "\n",
    "# Distinct name so the `reconstructions` of the random samples are not overwritten\n",
    "grid_reconstructions = decoder.predict(grid)\n",
    "\n",
    "fig = plt.figure(figsize=(figsize, figsize))\n",
    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
    "for i in range(grid_size**2):\n",
    "    ax = fig.add_subplot(grid_size, grid_size, i + 1)\n",
    "    ax.axis(\"off\")\n",
    "    # Drop the trailing channel axis, (H, W, 1) -> (H, W), so imshow applies the colormap\n",
    "    ax.imshow(grid_reconstructions[i].squeeze(), cmap=\"Greys\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
